"""
Structural break analysis for Demographics & Capital Flows project.

Tests:
1. Rolling-window estimation (15-year windows) — time-varying coefficients
2. Pre/post regime breaks: WTO (2001), GFC (2008), tariffs (2018), USSR (1991)
3. Chow-type tests for coefficient stability
4. Post-tariff residual monitoring for US and China
"""

import pandas as pd
import numpy as np
from pathlib import Path
from src.model import PanelGLS

# Project directories (hard-coded absolute paths). Inputs are read from
# PROCESSED_DIR; figures and CSV tables are written under OUTPUT_DIR.
PROCESSED_DIR = Path("/mnt/c/demographics_capital_flows/multilateral") / "data" / "processed"
OUTPUT_DIR = PROCESSED_DIR.parents[1] / "output"
FIG_DIR, TAB_DIR = OUTPUT_DIR / "figures", OUTPUT_DIR / "tables"


# ---------------------------------------------------------------------------
# Rolling-window estimation
# ---------------------------------------------------------------------------

def rolling_window_estimation(panel_df, window_size=15, step=1, min_obs=200,
                              min_countries=15):
    """
    Re-estimate the baseline model over rolling time windows.

    Windows cover ``window_size`` consecutive years (both ends inclusive),
    starting at the first sample year and advancing by ``step`` years while a
    full window still fits. Within each window, candidate controls are tried
    one at a time in a fixed order. A control is kept only if the sample with
    it and the controls already chosen has at least 60% of the window's
    observations and at least ``min_obs`` observations. Windows left with
    fewer than ``min_obs`` observations or ``min_countries`` countries are
    skipped. Windows whose estimation raises are reported and skipped.

    Parameters
    ----------
    panel_df : pandas.DataFrame
        Panel with ``ca_gdp``, ``Z_1``, ``Z_2``, ``Z_3``, ``year``, ``iso3``
        and any of the candidate control columns.
    window_size : int, default 15
        Window length in years; must be >= 1.
    step : int, default 1
        Years between successive window starts; must be >= 1.
    min_obs : int, default 200
        Minimum observations needed to keep a control and to estimate.
    min_countries : int, default 15
        Minimum number of distinct ``iso3`` values needed to estimate.

    Returns
    -------
    pandas.DataFrame
        One row per estimated window with ``window_start``, ``window_end``,
        ``window_mid``, ``n_obs``, ``n_countries``, ``r_squared``, ``rho``,
        ``controls`` (comma-separated) and ``{Z}_coef``, ``{Z}_se``,
        ``{Z}_pval`` for each of Z_1..Z_3. Empty when no window could be
        estimated, including when the input has no usable rows.

    Raises
    ------
    ValueError
        If ``window_size`` or ``step`` is below 1. A non-positive step would
        never advance the window.
    """
    if window_size < 1 or step < 1:
        raise ValueError(
            f"window_size and step must be >= 1, got window_size={window_size}, "
            f"step={step}")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    df = panel_df.dropna(subset=['ca_gdp'] + demo_vars).copy()
    if df.empty:
        # No usable rows: there is no year range to build windows from.
        return pd.DataFrame()

    # Same candidate controls as the baseline; coverage is checked per window.
    all_controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth',
                    'nfa_gdp_lag', 'log_rel_opw', 'life_expectancy']
    available_controls = [c for c in all_controls if c in df.columns]

    min_year = int(df['year'].min())
    max_year = int(df['year'].max())

    results = []
    start_year = min_year
    while start_year + window_size - 1 <= max_year:
        end_year = start_year + window_size - 1
        window_df = df[(df['year'] >= start_year) & (df['year'] <= end_year)]

        # Greedy control selection: each candidate is judged together with
        # the controls already accepted, so the order of all_controls matters.
        controls = []
        base_n = window_df.dropna(subset=['ca_gdp'] + demo_vars).shape[0]
        for c in available_controls:
            test_n = window_df.dropna(
                subset=['ca_gdp'] + demo_vars + controls + [c]).shape[0]
            if test_n >= 0.6 * base_n and test_n >= min_obs:
                controls.append(c)

        all_vars = demo_vars + controls
        est_df = window_df.dropna(subset=['ca_gdp'] + all_vars)

        if len(est_df) < min_obs or est_df['iso3'].nunique() < min_countries:
            start_year += step
            continue

        try:
            model = PanelGLS()
            model.fit(est_df['ca_gdp'].values, est_df[all_vars].values,
                      est_df['iso3'].values, est_df['year'].values)
            model.feature_names = all_vars

            row = {
                'window_start': start_year,
                'window_end': end_year,
                'window_mid': (start_year + end_year) / 2,
                'n_obs': model.n_obs,
                'n_countries': model.n_countries,
                'r_squared': model.r_squared,
                'rho': model.rho,
                'controls': ', '.join(controls),
            }
            # Demographic factors are the first regressors, so their
            # estimates occupy the first len(demo_vars) positions.
            for i, z in enumerate(demo_vars):
                row[f'{z}_coef'] = model.beta[i]
                row[f'{z}_se'] = model.se[i]
                row[f'{z}_pval'] = model.pvalues[i]
        except Exception as e:
            # Best effort: one failing window should not abort the sweep.
            print(f"  Window {start_year}-{end_year} failed: {e}")
        else:
            results.append(row)

        start_year += step

    return pd.DataFrame(results)


# ---------------------------------------------------------------------------
# Structural break tests (interaction approach)
# ---------------------------------------------------------------------------

def estimate_break_model(panel_df, break_year, break_label, controls=None,
                         min_obs=200):
    """
    Estimate the model with post-break interactions on the demographic factors.

        CA = Z*gamma + (Z * D_post)*gamma_post + X*beta + u

    where ``D_post = 1`` if ``year >= break_year`` and 0 otherwise. The
    ``{Z}_x_post`` coefficients (gamma_post) measure how the demographic
    effects change after the break.

    NOTE(review): D_post itself is not added as a regressor. Presumably any
    level shift is absorbed by PanelGLS's treatment of the year variable;
    confirm.

    Parameters
    ----------
    panel_df : pandas.DataFrame
        Panel with ``ca_gdp``, ``Z_1``..``Z_3``, ``year``, ``iso3`` and the
        candidate control columns.
    break_year : int
        First year of the post-break regime.
    break_label : str
        Used to name the dummy column (``post_{break_label}``) and in output.
    controls : list of str, optional
        Control columns to use. If None, candidates are added greedily from a
        fixed list while the sample keeps >= 70% of the observations with
        non-missing ``ca_gdp`` and Z factors.
    min_obs : int, default 200
        Minimum estimation-sample size; below it nothing is estimated.

    Returns
    -------
    tuple
        ``(model, est_df)``: the fitted ``PanelGLS`` model (regressors ordered
        Z, interactions, controls) and its estimation sample. Returns
        ``(None, None)`` if the sample has fewer than ``min_obs`` rows, or if
        it does not contain observations on both sides of the break.
    """
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    df = panel_df.dropna(subset=['ca_gdp'] + demo_vars).copy()

    # Post-break dummy and Z x D_post interaction terms.
    post_col = f'post_{break_label}'
    df[post_col] = (df['year'] >= break_year).astype(float)
    for z in demo_vars:
        df[f'{z}_x_post'] = df[z] * df[post_col]
    interaction_vars = [f'{z}_x_post' for z in demo_vars]

    if controls is None:
        all_controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth',
                        'nfa_gdp_lag', 'log_rel_opw', 'life_expectancy']
        controls = []
        base_n = df.dropna(subset=['ca_gdp'] + demo_vars).shape[0]
        for c in all_controls:
            if c in df.columns:
                test_n = df.dropna(subset=['ca_gdp'] + demo_vars + controls + [c]).shape[0]
                if test_n >= 0.7 * base_n:
                    controls.append(c)

    all_vars = demo_vars + interaction_vars + controls
    est_df = df.dropna(subset=['ca_gdp'] + all_vars)

    if len(est_df) < min_obs:
        print(f"  Insufficient obs for break test at {break_year}")
        return None, None

    # With a single regime the interactions are all zero (all pre) or equal
    # to Z (all post), so the design matrix would be rank deficient.
    if est_df[post_col].nunique() < 2:
        print(f"  No observations on both sides of break year {break_year}; "
              f"skipping {break_label} break test")
        return None, None

    model = PanelGLS()
    model.fit(est_df['ca_gdp'].values, est_df[all_vars].values,
              est_df['iso3'].values, est_df['year'].values)
    model.feature_names = all_vars

    print(f"\n{'='*70}")
    print(f"STRUCTURAL BREAK TEST: {break_label} (break year = {break_year})")
    print(f"{'='*70}")
    model.summary(feature_names=all_vars)

    return model, est_df


def estimate_split_sample(panel_df, break_year, break_label, controls=None,
                          min_obs=100, min_countries=10):
    """
    Estimate the model separately before and after a break year.

    The pre-period is ``year < break_year`` and the post-period is
    ``year >= break_year``. Both periods use the same regressors: Z_1..Z_3
    plus the controls. When ``controls`` is None, the controls are chosen
    once on the full sample, adding candidates while at least 70% of the
    observations remain.

    Parameters
    ----------
    panel_df : pandas.DataFrame
        Panel with ``ca_gdp``, ``Z_1``..``Z_3``, ``year``, ``iso3`` and the
        candidate control columns.
    break_year : int
        First year of the post period.
    break_label : str
        Label used in printed output.
    controls : list of str, optional
        Control columns; selected automatically if None (see above).
    min_obs : int, default 100
        A period with fewer observations is not estimated.
    min_countries : int, default 10
        A period with fewer distinct ``iso3`` values is not estimated.

    Returns
    -------
    tuple
        ``(results, all_vars)``. ``results`` maps ``'pre'`` and ``'post'`` to
        ``(model, est_df)``. ``model`` is the fitted ``PanelGLS``, or None if
        that period had insufficient data. ``est_df`` is the period's
        estimation sample. ``all_vars`` lists the regressor names in model
        order.
    """
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    df = panel_df.dropna(subset=['ca_gdp'] + demo_vars).copy()

    if controls is None:
        all_controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth',
                        'nfa_gdp_lag', 'log_rel_opw', 'life_expectancy']
        controls = []
        base_n = df.dropna(subset=['ca_gdp'] + demo_vars).shape[0]
        for c in all_controls:
            if c in df.columns:
                test_n = df.dropna(subset=['ca_gdp'] + demo_vars + controls + [c]).shape[0]
                if test_n >= 0.7 * base_n:
                    controls.append(c)

    all_vars = demo_vars + controls

    results = {}
    for period_name, period_df in [('pre', df[df['year'] < break_year]),
                                   ('post', df[df['year'] >= break_year])]:
        est_df = period_df.dropna(subset=['ca_gdp'] + all_vars)
        if len(est_df) < min_obs or est_df['iso3'].nunique() < min_countries:
            print(f"  {period_name}-{break_label}: insufficient data "
                  f"({len(est_df)} obs, {est_df['iso3'].nunique()} countries)")
            results[period_name] = (None, est_df)
            continue

        model = PanelGLS()
        model.fit(est_df['ca_gdp'].values, est_df[all_vars].values,
                  est_df['iso3'].values, est_df['year'].values)
        model.feature_names = all_vars

        print(f"\n--- {period_name.upper()}-{break_label} ({period_name} {break_year}) ---")
        model.summary(feature_names=all_vars)
        results[period_name] = (model, est_df)

    return results, all_vars


# ---------------------------------------------------------------------------
# Transition economy analysis (USSR collapse)
# ---------------------------------------------------------------------------

# Default set of transition economies (ISO3 codes) flagged by
# estimate_transition_model; kept as a single list in this order.
TRANSITION_COUNTRIES = (
    ['RUS', 'CZE', 'HUN', 'POL']             # in EBA 49
    + ['SVK', 'SVN', 'EST', 'LVA', 'LTU']    # EU accession
    + ['UKR', 'KAZ', 'BGR', 'ROU', 'HRV']    # other transition
)


def estimate_transition_model(panel_df, transition_year=1995,
                              transition_countries=None, min_obs=200):
    """
    Test whether demographics 'turned on' for transition economies after
    market integration.

    Adds interactions ``Z * D`` for each demographic factor. ``D = 1`` for a
    transition country in a year ``>= transition_year`` and 0 otherwise.

    Parameters
    ----------
    panel_df : pandas.DataFrame
        Panel with ``ca_gdp``, ``Z_1``..``Z_3``, ``year``, ``iso3`` and the
        candidate control columns.
    transition_year : int, default 1995
        First year in which the transition dummy can be 1.
    transition_countries : iterable of str, optional
        ISO3 codes treated as transition economies; defaults to
        ``TRANSITION_COUNTRIES``.
    min_obs : int, default 200
        Minimum estimation-sample size; the same threshold the interaction
        break test uses.

    Returns
    -------
    tuple
        ``(model, est_df)``: the fitted ``PanelGLS`` model (regressors ordered
        Z, interactions, controls) and its estimation sample. Returns
        ``(None, None)`` if the sample has fewer than ``min_obs`` rows, or if
        the transition dummy does not vary in it. In that case the
        interactions would be all zero or identical to Z.
    """
    countries = (TRANSITION_COUNTRIES if transition_countries is None
                 else list(transition_countries))
    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    df = panel_df.dropna(subset=['ca_gdp'] + demo_vars).copy()

    # Transition dummy: is a transition country AND post-integration
    df['is_transition'] = df['iso3'].isin(countries).astype(float)
    df['post_transition'] = ((df['is_transition'] == 1) &
                             (df['year'] >= transition_year)).astype(float)

    for z in demo_vars:
        df[f'{z}_x_transition'] = df[z] * df['post_transition']

    interaction_vars = [f'{z}_x_transition' for z in demo_vars]

    # Controls: add candidates greedily while >= 70% of observations remain.
    all_controls = ['fiscal_bal_gdp', 'kaopen', 'expected_growth',
                    'nfa_gdp_lag', 'log_rel_opw', 'life_expectancy']
    controls = []
    base_n = df.dropna(subset=['ca_gdp'] + demo_vars).shape[0]
    for c in all_controls:
        if c in df.columns:
            test_n = df.dropna(subset=['ca_gdp'] + demo_vars + controls + [c]).shape[0]
            if test_n >= 0.7 * base_n:
                controls.append(c)

    all_vars = demo_vars + interaction_vars + controls
    est_df = df.dropna(subset=['ca_gdp'] + all_vars)

    print(f"\n{'='*70}")
    print(f"TRANSITION ECONOMY TEST (break year = {transition_year})")
    print(f"  Transition countries in sample: "
          f"{sorted(est_df[est_df['is_transition']==1]['iso3'].unique())}")
    print(f"{'='*70}")

    if len(est_df) < min_obs:
        print(f"  Insufficient obs for transition test ({len(est_df)} obs)")
        return None, None

    # A constant dummy makes the interaction columns all zero or equal to Z,
    # leaving the design matrix rank deficient.
    if est_df['post_transition'].nunique() < 2:
        print("  Transition dummy does not vary in the sample; skipping test")
        return None, None

    model = PanelGLS()
    model.fit(est_df['ca_gdp'].values, est_df[all_vars].values,
              est_df['iso3'].values, est_df['year'].values)
    model.feature_names = all_vars
    model.summary(feature_names=all_vars)

    return model, est_df


# ---------------------------------------------------------------------------
# Post-tariff residual monitoring
# ---------------------------------------------------------------------------

def tariff_residual_analysis(panel_df, baseline_model, baseline_df,
                              focus_countries=None, tariff_year=2018):
    """
    Compare baseline-model residuals before and after a tariff year.

    The analysis tracks residuals for the US, China and others around tariff
    imposition. A growing residual is read as trade policy overriding
    demographic fundamentals.

    Parameters
    ----------
    panel_df : pandas.DataFrame
        Not used; kept for interface compatibility.
    baseline_model : object
        Fitted baseline model. Only its ``resid`` attribute is read, and only
        when ``baseline_df`` has no ``resid_baseline`` column. The rows must
        then line up with ``baseline_df``.
    baseline_df : pandas.DataFrame
        Baseline estimation sample with ``iso3``, ``year`` and ``ca_gdp``.
        It is copied, not modified.
    focus_countries : list of str, optional
        ISO3 codes to analyse; defaults to USA, CHN, DEU, JPN, KOR, MEX, CAN.
        Codes absent from ``baseline_df`` are skipped.
    tariff_year : int, default 2018
        First post-tariff year (``year >= tariff_year`` counts as post).

    Returns
    -------
    tuple
        ``(result_df, yearly)``. ``result_df`` has one row per country found,
        with columns ``iso3``, ``pre_mean_resid``, ``post_mean_resid``,
        ``pre_n``, ``post_n`` and ``resid_shift``. ``resid_shift`` is the post
        mean minus the pre mean, and NaN if either side is empty. The columns
        exist even when no country matches. ``yearly`` maps each found ISO3
        code to its year-sorted rows with ``year``, ``resid_baseline``,
        ``ca_gdp`` and, only if present in ``baseline_df``,
        ``fitted_baseline``.
    """
    if focus_countries is None:
        focus_countries = ['USA', 'CHN', 'DEU', 'JPN', 'KOR', 'MEX', 'CAN']

    # Get residuals from baseline model estimation
    df = baseline_df.copy()
    if 'resid_baseline' not in df.columns:
        df['resid_baseline'] = baseline_model.resid

    summary_cols = ['iso3', 'pre_mean_resid', 'post_mean_resid',
                    'pre_n', 'post_n', 'resid_shift']
    # 'fitted_baseline' is never created here, so only request it if the
    # caller's frame already has it (otherwise selection raises KeyError).
    yearly_cols = [c for c in ['year', 'resid_baseline', 'ca_gdp', 'fitted_baseline']
                   if c in df.columns]

    results = []
    yearly = {}
    for iso3 in focus_countries:
        cdf = df[df['iso3'] == iso3].sort_values('year')
        if cdf.empty:
            continue

        pre = cdf.loc[cdf['year'] < tariff_year, 'resid_baseline']
        post = cdf.loc[cdf['year'] >= tariff_year, 'resid_baseline']
        pre_mean = pre.mean() if len(pre) > 0 else np.nan
        post_mean = post.mean() if len(post) > 0 else np.nan

        results.append({
            'iso3': iso3,
            'pre_mean_resid': pre_mean,
            'post_mean_resid': post_mean,
            'pre_n': len(pre),
            'post_n': len(post),
            'resid_shift': post_mean - pre_mean,
        })
        # Year-by-year residuals for plotting.
        yearly[iso3] = cdf[yearly_cols]

    result_df = pd.DataFrame(results, columns=summary_cols)
    print(f"\n{'='*70}")
    print(f"POST-TARIFF RESIDUAL ANALYSIS (tariff year = {tariff_year})")
    print(f"{'='*70}")
    print(result_df.to_string(index=False, float_format='%.3f'))
    print("\nPositive shift = CA higher than model predicts post-tariffs")
    print("Negative shift = CA lower than model predicts post-tariffs")

    return result_df, yearly


# ---------------------------------------------------------------------------
# Main runner
# ---------------------------------------------------------------------------

def run_structural_break_analysis(panel_df=None):
    """
    Run all structural break tests and save results.

    The steps are, in order:
    - rolling-window coefficients;
    - interaction and split-sample break tests for WTO (2001), GFC (2008)
      and tariffs (2018);
    - the transition-economy test;
    - post-tariff residual monitoring against the baseline model;
    - a split-sample summary table.

    CSV outputs are written to ``TAB_DIR``, which is created if missing.

    Parameters
    ----------
    panel_df : pandas.DataFrame, optional
        Full panel. If None, it is read from ``PROCESSED_DIR /
        "full_panel.csv"``. Only years 1986-2024 (inclusive) are used.

    Returns
    -------
    dict
        Keys ``'rolling'``, ``'break_models'``, ``'split_results'``,
        ``'transition'``, ``'tariff_residuals'``, ``'baseline'`` and
        ``'summary'``. ``'summary'`` is None if no split-sample model was
        estimated.
    """
    if panel_df is None:
        panel_df = pd.read_csv(PROCESSED_DIR / "full_panel.csv")

    # DataFrame.to_csv does not create missing parent directories.
    TAB_DIR.mkdir(parents=True, exist_ok=True)

    # Filter to the estimation sample; copy so later column assignments never
    # act on a slice of the caller's frame.
    panel_df = panel_df[(panel_df['year'] >= 1986) & (panel_df['year'] <= 2024)].copy()

    print("\n" + "#" * 70)
    print("# STRUCTURAL BREAK ANALYSIS")
    print("#" * 70)

    all_results = {}

    # 1. Rolling-window estimation
    print("\n>>> Rolling-window estimation (15-year windows) <<<")
    rolling = rolling_window_estimation(panel_df, window_size=15, step=1)
    rolling.to_csv(TAB_DIR / "rolling_window_coefficients.csv", index=False)
    print(f"  Saved rolling_window_coefficients.csv ({len(rolling)} windows)")
    all_results['rolling'] = rolling

    # 2. Structural break tests
    breaks = [
        (2001, 'wto', 'China WTO Accession'),
        (2008, 'gfc', 'Global Financial Crisis'),
        (2018, 'tariff', 'Trump Tariffs'),
    ]

    break_models = {}
    split_results = {}
    for break_year, label, description in breaks:
        print(f"\n>>> Break test: {description} ({break_year}) <<<")

        # Interaction model
        model, est_df = estimate_break_model(panel_df, break_year, label)
        if model is not None:
            break_models[label] = (model, est_df)
            model.to_dataframe().to_csv(
                TAB_DIR / f"break_test_{label}.csv", index=False)

        # Split sample (the regressor list is not needed here)
        split_results[label], _ = estimate_split_sample(panel_df, break_year, label)

    all_results['break_models'] = break_models
    all_results['split_results'] = split_results

    # 3. Transition economy test
    print("\n>>> Transition economy test <<<")
    trans_model, trans_df = estimate_transition_model(panel_df)
    if trans_model is not None:
        trans_model.to_dataframe().to_csv(
            TAB_DIR / "break_test_transition.csv", index=False)
    all_results['transition'] = (trans_model, trans_df)

    # 4. Post-tariff residual monitoring (needs the baseline model residuals)
    from src.model import estimate_baseline_model
    baseline_model, baseline_df = estimate_baseline_model(panel_df)

    tariff_summary, tariff_yearly = tariff_residual_analysis(
        panel_df, baseline_model, baseline_df)
    tariff_summary.to_csv(TAB_DIR / "tariff_residual_shift.csv", index=False)
    all_results['tariff_residuals'] = (tariff_summary, tariff_yearly)
    all_results['baseline'] = (baseline_model, baseline_df)

    # 5. Summary comparison table of the split-sample estimates
    summary_rows = []
    for break_year, label, description in breaks:
        for period in ('pre', 'post'):
            model, _ = split_results.get(label, {}).get(period, (None, None))
            if model is None:
                continue
            op = '<' if period == 'pre' else '>='
            row = {
                'Specification': f'{period.title()}-{description}',
                'Period': f'{op}{break_year}',
                'N obs': model.n_obs,
                'N countries': model.n_countries,
                'R²': model.r_squared,
                'ρ': model.rho,
            }
            # Z factors are the first regressors in every split-sample model.
            for i, z in enumerate(['Z_1', 'Z_2', 'Z_3']):
                row[f'{z}_coef'] = model.beta[i]
                row[f'{z}_pval'] = model.pvalues[i]
            summary_rows.append(row)

    summary_df = None
    if summary_rows:
        summary_df = pd.DataFrame(summary_rows)
        summary_df.to_csv(TAB_DIR / "structural_break_summary.csv", index=False)
        print("\n  Saved structural_break_summary.csv")
        print("\n" + "=" * 70)
        print("SPLIT-SAMPLE COMPARISON")
        print("=" * 70)
        print(summary_df.to_string(index=False, float_format='%.4f'))
    all_results['summary'] = summary_df

    return all_results


if __name__ == "__main__":
    # Script entry point: run every analysis, reading the panel from
    # PROCESSED_DIR / "full_panel.csv" (panel_df defaults to None).
    run_structural_break_analysis()
